{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import math\n",
    "import argparse\n",
    "from torch.autograd import Variable\n",
    "from augerino import datasets, models, losses\n",
    "import glob\n",
    "import re\n",
    "import pandas as pd\n",
    "import sys\n",
    "\n",
    "from data.generate_data import *\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "//miniconda3/lib/python3.7/site-packages/torch/nn/functional.py:3327: UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.\n",
      "  warnings.warn(\"Default grid_sample and affine_grid behavior has changed \"\n",
      "//miniconda3/lib/python3.7/site-packages/torch/nn/functional.py:3264: UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.\n",
      "  warnings.warn(\"Default grid_sample and affine_grid behavior has changed \"\n"
     ]
    }
   ],
   "source": [
    "softplus = torch.nn.Softplus()\n",
    "savedir = \"./saved-outputs/\"\n",
    "\n",
    "ntrain = 10000\n",
    "ntest = 5000\n",
    "\n",
    "trainloader, testloader = generate_mario_data(ntrain=ntrain, ntest=ntest,\n",
    "                                              batch_size=128, dpath=\"./data/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trainer(model, reg=0.01, epochs=20):\n",
    "    \"\"\"Train `model` on the notebook-level `trainloader`, logging `model.aug.width`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model : callable module exposing `parameters()`, `cuda()` and `aug.width`\n",
    "        (this notebook passes `models.AugAveragedModel` instances).\n",
    "    reg : float\n",
    "        Forwarded to `losses.unif_aug_loss` as its `reg` keyword.\n",
    "    epochs : int\n",
    "        Number of full passes over `trainloader`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame\n",
    "        One row per optimizer step with columns `index`, `width0..`\n",
    "        (softplus of `model.aug.width`, read after the step), `grad0..`\n",
    "        (`model.aug.width.grad` from that step's backward pass) and `loss`.\n",
    "        With no batches (e.g. `epochs=0`) the frame is empty but keeps\n",
    "        these columns.\n",
    "    \"\"\"\n",
    "    use_cuda = torch.cuda.is_available()\n",
    "    if use_cuda:\n",
    "        model = model.cuda()\n",
    "\n",
    "    # Build the optimizer only after the optional device move, as the\n",
    "    # torch.optim documentation advises.\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.)\n",
    "    criterion = losses.unif_aug_loss\n",
    "\n",
    "    # Name the columns after the actual number of width entries rather than\n",
    "    # a hard-coded 6.\n",
    "    n_width = model.aug.width.numel()\n",
    "    columns = (['width' + str(i) for i in range(n_width)]\n",
    "               + ['grad' + str(i) for i in range(n_width)]\n",
    "               + ['loss'])\n",
    "\n",
    "    logger = []\n",
    "    for epoch in range(epochs):  # loop over the dataset multiple times\n",
    "        # each batch is an [inputs, labels] pair\n",
    "        for inputs, labels in trainloader:\n",
    "            if use_cuda:\n",
    "                inputs, labels = inputs.cuda(), labels.cuda()\n",
    "\n",
    "            # zero the parameter gradients\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # forward + backward + optimize\n",
    "            outputs = model(inputs)\n",
    "            loss = criterion(outputs, labels, model, reg=reg)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            # `softplus` is the notebook-level torch.nn.Softplus() instance.\n",
    "            width = model.aug.width\n",
    "            log = softplus(width).detach().flatten().tolist()\n",
    "            log += width.grad.detach().flatten().tolist()\n",
    "            log.append(loss.item())\n",
    "            logger.append(log)\n",
    "\n",
    "    # Passing `columns` to the constructor keeps the schema even when `logger`\n",
    "    # is empty; assigning `.columns` afterwards would raise a length mismatch.\n",
    "    return pd.DataFrame(logger, columns=columns).reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "savedir = \"../limit-invariance/saved-outputs/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-663e4a3e96a9>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0mhigh_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maug\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_width\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstart_widths\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mhigh_logger\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhigh_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreg\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-3-18fde402f464>\u001b[0m in \u001b[0;36mtrainer\u001b[0;34m(model, reg, epochs)\u001b[0m\n\u001b[1;32m     27\u001b[0m             loss = criterion(outputs, labels, model,\n\u001b[1;32m     28\u001b[0m                             reg=reg)\n\u001b[0;32m---> 29\u001b[0;31m             \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     30\u001b[0m             \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m//miniconda3/lib/python3.7/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m    196\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    197\u001b[0m         \"\"\"\n\u001b[0;32m--> 198\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    200\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m//miniconda3/lib/python3.7/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m     98\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[1;32m     99\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 100\u001b[0;31m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[1;32m    101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    102\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "net = models.SimpleConv(c=32, num_classes=4)\n",
    "augerino = models.UniformAug()\n",
    "high_model = models.AugAveragedModel(net, augerino,ncopies=1)\n",
    "\n",
    "start_widths = torch.ones(6) * -5.\n",
    "start_widths[2] = -1.\n",
    "high_model.aug.set_width(start_widths)\n",
    "\n",
    "high_logger = trainer(high_model, reg=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(high_model.state_dict(), savedir + \"highreg.pt\")\n",
    "high_logger.to_pickle(savedir + \"high_logger.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = models.SimpleConv(c=32, num_classes=4)\n",
    "augerino = models.UniformAug()\n",
    "low_model = models.AugAveragedModel(net, augerino,ncopies=1)\n",
    "\n",
    "start_widths = torch.ones(6) * -5.\n",
    "start_widths[2] = -1.\n",
    "\n",
    "low_model.aug.set_width(start_widths)\n",
    "low_logger = trainer(low_model, reg=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(low_model.state_dict(), savedir + \"lowreg.pt\")\n",
    "low_logger.to_pickle(savedir + \"low_logger.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = models.SimpleConv(c=32, num_classes=4)\n",
    "augerino = models.UniformAug()\n",
    "mid_model = models.AugAveragedModel(net, augerino,ncopies=1)\n",
    "\n",
    "start_widths = torch.ones(6) * -5.\n",
    "start_widths[2] = -1.\n",
    "\n",
    "mid_model.aug.set_width(start_widths)\n",
    "mid_logger = trainer(mid_model, reg=0.05)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the mid-regularization run: weights come from mid_model (not high_model).\n",
    "torch.save(mid_model.state_dict(), savedir + \"midreg.pt\")\n",
    "mid_logger.to_pickle(savedir + \"mid_logger.pkl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add a band centred on zero whose total extent is the logged `width2`.\n",
    "for _run_log in (low_logger, high_logger, mid_logger):\n",
    "    _half_width = _run_log['width2'] / 2.\n",
    "    _run_log['lowbd'] = -_half_width\n",
    "    _run_log['upbd'] = _half_width"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 0.1\n",
    "lwd = 0.\n",
    "\n",
    "def plot_shade(logger, ax, color, label=\"\"):\n",
    "    \"\"\"Draw the band between the `lowbd` and `upbd` columns of `logger` on `ax`.\n",
    "\n",
    "    The band is filled using the notebook-level `alpha` and `lwd` settings and\n",
    "    outlined with two seaborn lines; only the lower outline carries `label`.\n",
    "    \"\"\"\n",
    "    ax.fill_between(logger.index, logger['lowbd'], logger['upbd'],\n",
    "                    alpha=alpha, color=color,\n",
    "                    linewidth=lwd)\n",
    "    # Pass `ax` explicitly so the outlines land on the same axes as the fill,\n",
    "    # not on whichever axes pyplot currently treats as active.\n",
    "    sns.lineplot(x=logger.index, y='lowbd', color=color, data=logger, label=label, ax=ax)\n",
    "    sns.lineplot(x=logger.index, y='upbd', color=color, data=logger, ax=ax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tick_pts = [-np.pi/2, -np.pi/4, 0, np.pi/4, np.pi/2]\n",
    "tick_labs = [r\"-$\\pi$/2\", r'-$\\pi$/4', '0', r'$\\pi$/4', r'$\\pi$/2']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax0 = plt.subplots(1, 1, figsize=(8, 4), dpi=100)\n",
    "fs = 14\n",
    "pal = sns.color_palette(\"tab10\")\n",
    "col0 = pal[0]\n",
    "col1 = pal[1]\n",
    "col2 = pal[2]\n",
    "\n",
    "plot_shade(low_logger, ax0, col0, \"Low Reg\")\n",
    "plot_shade(mid_logger, ax0, col1, \"Mid Reg\")\n",
    "plot_shade(high_logger, ax0, col2, \"High Reg\")\n",
    "\n",
    "# ax0.set_title(\"Rotation Distributions\")\n",
    "ax0.set_xlabel(\"Iteration\", fontsize=fs)\n",
    "ax0.set_ylabel(\"Rotation Width\", fontsize=fs)\n",
    "# ax0.set_title(\"CE Losses\")\n",
    "ax0.tick_params(\"both\", labelsize=fs-2)\n",
    "sns.despine()\n",
    "ax0.set_xticks([])\n",
    "ax0.set_yticks(tick_pts)\n",
    "ax0.set_yticklabels(tick_labs)\n",
    "# ax0.set_xlim(0, 500)\n",
    "# ax0.legend()\n",
    "# plt.setp(ax0.get_legend().get_texts(), fontsize=fs-4) # for legend text\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    " "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
